{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from torch import load, save\n",
    "from numpy import sort\n",
    "from random import shuffle\n",
    "from PIL import Image\n",
    "from matplotlib import pylab as pl\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set user-defined parameters\n",
    "\n",
    "inatpath = './'            # The location of your downloaded 'train_val_images' folder. Make sure it ends in '/'!\n",
    "annopath = './'            # The location of your downloaded 'train_2017_bboxes.json' file. Make sure it ends in '/'!\n",
    "datapath = './'            # Wherever you want your new dataset to appear. Make sure it ends in '/'!\n",
    "catsize_min = 50           # The smallest category allowed in our data set\n",
    "catsize_max = 1000         # The largest category allowed in our data set\n",
    "random_assign = False      # Split categories randomly over the representation and evaluation sets, \n",
    "                           # or use the splits from the paper?\n",
    "if not random_assign:\n",
    "    assert catsize_min==50 and catsize_max==1000, 'The provided splits work only for category sizes between 50 and 1000.'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Built annotation dictionaries\n"
     ]
    }
   ],
   "source": [
    "# Compile bounding box annotations for each image id\n",
    "\n",
    "with open(annopath+'train_2017_bboxes.json') as f:\n",
    "    allinfo = json.load(f)\n",
    "annolist = allinfo['annotations']\n",
    "\n",
    "annodict = dict() # im_id to list of box_ids\n",
    "boxdict = dict() # box_id to box coords\n",
    "catdict = dict() # dict of numerical category codes / labels to corresponding list of image ids\n",
    "for d in annolist:\n",
    "    im = d['image_id']\n",
    "    boxid = d['id']\n",
    "    cat = d['category_id']\n",
    "    \n",
    "    # Add box_id to image entry\n",
    "    if im in annodict:\n",
    "        annodict[im].append(boxid)\n",
    "    else:\n",
    "        annodict[im] = [boxid]\n",
    "        \n",
    "    # Add mapping from box_id to box\n",
    "    boxdict[boxid] = d['bbox']\n",
    "    \n",
    "    # Add image to category set\n",
    "    if cat in catdict:\n",
    "        catdict[cat].add(im)\n",
    "    else:\n",
    "        catdict[cat] = set([im])\n",
    "    \n",
    "print(\"Built annotation dictionaries\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Built path dictionary\n"
     ]
    }
   ],
   "source": [
    "# assemble im_id -> filepath dictionary\n",
    "namelist = allinfo['images']\n",
    "keys = []\n",
    "vals = []\n",
    "for d in namelist:\n",
    "    keys.append(d['id'])\n",
    "    vals.append(inatpath+d['file_name'])\n",
    "pather = dict(zip(keys,vals))\n",
    "\n",
    "print(\"Built path dictionary\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Detected 1135 categories of the desired size\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEWCAYAAABxMXBSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd8X2Xd//HXJ3s3szvdgw5WKaNlU2SJjPsGtwI3inh7K4q3Ci4cvxu3KA4UBZmKiiiIKFZkiaXQQls66KAzXRlt0qRt0iT9/P44V9pv0yRNmvHNeD8fj/PI91zn+p7zufJtv59c1znnOubuiIiItFdCvAMQEZG+RYlDREQ6RIlDREQ6RIlDREQ6RIlDREQ6RIlDREQ6RIlDpJuZ2efN7Jc9eLxUM1tuZkN76pgdYWZfMbOHwushZrbCzFLjHZe0nxKHdCsze6+ZLTCzGjPbamZ/NbMz2vleN7MJ3R1jVzCzy81skZntMrNyM3vGzMYAuPvt7v6hHgznBuAFd98WYhtpZn8IcVWZ2Rtmdu3R7tzM1pvZ+V0RqLtvB54NMUsfocQh3cbMbgZ+ANwODAFGAT8FLo9nXEdiZkkdrD8BeAD4NDAIGEvUzv1dH127fAR4MGb9QWATMBooAD4IbO/oTjv6e+mAh4lilr7C3bVo6fKF6Au0Bri6jTqnAPOASmAr8GMgJWx7AXBgd9jPu0L5pcCi8J5/A8fF7G8G8DpQDfwe+C3w/2K2fxhYA+wAngCGx2xz4GPAamAd8BPge83i/TPwyRbacRWwqI12fgV4KLz+cWhP09IAfCVsGw78ASgLMXyi2e9qAbCL6Ev/+60caxSwF0iKKasBTmgjvsuAZeF3+hwwJWbbeuBzwBKgDvgNUULcG/b72VDvtPB5VAKLgXNi9jEWeD58LnPD7+ChmO1JwB5gdLz/3Wpp3xL3ALT0zwW4KHwpJrVR56TwhZMEjAFWxH4xhy/zCTHrM4BS4FQgEbgmfLGlAinABuAmIBn4D2BfU+IAzgPKwz5SgR8RDefEHmsukA+khy/qLUBC2F4YvtyGtNCOcUAtcAdwLpDVbPuBxNGs/ISQJE4k6v0vBL4c2jIOWAtcGOrOAz4QXmcBp7XyO307sKxZ2T+Al4B3A6OabZtElJzfFn5vnyVKrk0JfD1Roi4G0mPKzo/ZxwigArgktONtYb0oJvbvh9/7WSGBPNQsjiXAZfH+d6ulfYuGqqS7FADl7t7QWgV3X+juL7t7g7uvB34OnN3GPj8M/Nzd57t7o7vfT/RX8GkcTEB3unu9uz8GvBLz3vcB97r7a+5eB9wKzGo6DxF8w913uPted38FqALmhG3vBp7zaEy+eTvWAucQfYH+Dig3s/vMLKu1hphZEfAn4OPu/jpwMtEX7dfcfV/Y5y/CcQHqgQlmVujuNe7+ciu7ziX6Yo51NfAi8CVgXTgXc3LY9i7gL+4+193rge8SJc7ZMe+/0903ufveVo75fuApd3/K3fe7+1yi3tElZjYqtO1L7l7n7i8Q9dyaqw6xSx+gxCHdpQIobGtc3MwmmdmTZrbNzHYRnQspbGOfo4FPm1ll00L0l/DwsGx299hZOzfFvB5O1CMBwN1rQowjWqkPcD/RlyLh54O0IiTAd7p7EXAm0V/WX2iprpklA48Cv3b3R2LaNrxZ2z5PdG4I4Hqi3sGbZvaqmV3aSig7gexmse1091vcfVrY3yLgT2ZmHP572R9+D239XpobDVzdLPYzgGFh/zvdfXdM/Q0t7CObaJhL+gAlDuku84iGb65oo85dwJvARHfPIfqitDbqbwL+z91zY5YMd/8N0TmSEeHLsElxzOstRF9wAJhZJlGvaHNMneZTRT8EXG5mxwNTiHoIR+TurwKPAdNbqfIjor+wv9isbeuatS3b3S8J+1zt7u8BBgPfAh4NbWhuCTCutYTt7uVEvYrhRMNyzX8vRvR7a+v30nx9E/Bgs9gz3f2bRJ9LXrNYR8W+OcQ6gejciPQBShzSLdy9imi8/idmdoWZZZhZspldbGbfDtWyiU721pjZMcBHm+1mO9FYf5NfADea2akWyTSzt5tZNlGiagT+x8ySzOxyovMUTX4NXGdmJ4R7Bm4H5ochstbaUAK8StTT+ENrQzVmdoaZfdjMBof1Y4hOOB82nGRmHyEajntv+Ou+ySvALjP7nJmlm1mimU1vGlIys/ebWVF4T9Nf5o2txLw6tu1m9q2wr6Twu/oosMbdK4iG1t5uZnNCT+jTRMN//27t98Lhn8tDwDvM7MIQd5qZnWNmI919A9Gw1VfNLCVciv2OZvs7BVgf6kpfEO+TLFr690J0bmEB0QnYbcBfgNlh21lEPY4aojH4rwH/innvjUR/sVYC7wxlFxF9mTddifV7IDtsm0k0DFMTyh8jGluP3d9bRFdVPQmMjNl2yIn4mPL3h23nttHG6UTj9tvDsdcT9QqSw/avcPCqqueIvphjr6z6fNg2nOiqpW1EQ04vE05CE305l4b6y4Ar2ojnY8BdMes/IkomNUQn45/k0CunrgSWE53TeR6YFrNtPTEnwkPZ5cDG8Bn8byg7Nbx3RzjGXwgn4omSzIvh+C1dVfUTYq4g09L7FwsfnEi/Y2bzgZ+5+686sY+ziL60x/ihPYReK/SoXgfmuPvWeMfTltBLex440d1r4x2PtI8Sh/QbZnY2sJLostv3AT8Dxh3tl2cYunkEWOzuX+uyQEX6uO66E1QkHiYTjdlnEQ1JXdWJpDGFaIhtMXBdl0Uo0g+oxyEiIh2iq6pERKRD+uVQVWFhoY8ZMybeYYiI9CkLFy4s9+gm1jb1y8QxZswYFixYEO8wRET6FDNr1700GqoSEZEOUeIQEZEOUeIQEZEO6bbEYWb3mlmpmS2NKcs3s7lmtjr8zAvlZmZ3mtkaM1tiZjNi3nNNqL/azK7prnhFRKR9urPHcR/RvEKxbgGecfeJwDNhHeBiYGJYbiCaNRUzywduI5oH5xTgtqZkIyIi8dFticOjB7bsaFZ8OdEzDgg/r4gpf8AjLwO5ZjYMuBCY69HDdXYSTZDWPBmJiEgP6ulzHEOapoAIPweH8hEc+rCYklDWWvlhzOwGM1tgZgvKysq6PHAREYn0lpPjLT28x9soP7zQ/W53n+nuM4uKjnj/Sou2VO7l+39fybry3UeuLCIyQPV04tgehqAIP0tDeQmHPq1tJNGTyVor7xY7du/jzn+uYfX25o9sFhGRJj2dOJ4Amq6MugZ4PKb8g+HqqtOAqjCU9TRwgZnlhZPiF4SybpGTlgzArtqG7jqEiEif121TjpjZb4BzgEIzKyG6OuqbwO/M7HqiJ4hdHao/BVwCrAH2EKaxdvcdZvZ1oie+AXzN3ZufcO8y2WnRr6O6tr67DiEi0ud1W+Jw9/e0smlOC3Wd6HGXLe3nXuDeLgytVQcTh3ocIiKt6S0nx3uFpMQEctKSeGlNebxDERHptZQ4mvnArNHMX7eD0mo9/lhEpCVKHM3MHJMPwMaKPXGORESkd1LiaKY4LwOATTuVOEREWqLE0czIvHQASnbsjXMkIiK9kxJHM2nJiQzOTtXd4yIirVDiaMHJY/P591sV8Q5DRKRXUuJowbEjBrFtVy1l1XXxDkVEpNdR4mjBzNHRIz/+unRrnCMREel9lDhacNLoPJITjS2VupdDRKQ5JY4WmBkFmalU1GioSkSkOSWOVhRlp7K+QldWiYg0p8TRijMmFrJgw07qGhrjHYqISK+ixNGK8UVZuGvqERGR5pQ4WjFteA4ATy7RlVUiIrGUOFoxZVgOU4bl8NrGnfEORUSkV1HiaMO4wkw2V2rOKhGRWEocbSjKTtXd4yIizShxtKEoO5Xq2gb27NOjZEVEmihxtGHqsOgE+aKNlXGORESk91DiaMOkodkArNcluSIiByhxtGFwdipmsG2X5qwSEWmixNGG5MQECjJTKKtW4hARaaLEcQS5GSlU7qmPdxgiIr2GEscR5KYnK3GIiMRQ4jiC3IwU5q2toKZOl+SKiIASxxGdPqEAgEde2RjnSEREegcljiO47vSxjMhN15xVIiKBEkc7zBidx+u6CVBEBFDiaJcZo3LZWlXL1ipNeCgiosTRDjNG5QHw2gb1OkRElDjaYcqwHMxg5fbqeIciIhJ3ShztkJKUQGFWKturdAe5iEhcEoeZfcrMlpnZUjP7jZmlmdlYM5tvZqvN7LdmlhLqpob1NWH7mHjEPHxQGotLKtnXsD8ehxcR6TV6PHGY2QjgE8BMd58OJALvBr4F3OHuE4GdwPXhLdcDO919AnBHqNfjPjBrDG9uq+aBeevjcXgRkV4jXkNVSUC6mSUBGcBW4Dzg0bD9fuCK8PrysE7YPsfMrAdjBeCqk0YyrjCT51eV9fShRUR6lR5PHO6+GfgusJEoYVQBC4FKd2+a16MEGBFejwA2hfc2hPoFzfdrZjeY2QIzW1BW1j1f7qdPKOTF1eW8tKa8W/YvItIXxGOoKo+oFzEWGA5kAhe3UNWb3tLGtoMF7ne7+0x3n1lUVNRV4R7ify+YDMDiEl2WKyIDVzyGqs4H1rl7mbvXA48Bs4HcMHQFMBLYEl6XAMUAYfsgYEfPhhwZlJHM0Jw0Xl4bl8OLiPQK8UgcG4HTzCwjnKuYAywHngWuCnWuAR4Pr58I64Tt/3T3w3ocPeW8KYN5XfNWicgAFo9zHPOJTnK/BrwRYrgb+Bxws5mtITqHcU94yz1AQSi/Gbilp2OONSI3neraBmrrG+MZhohI3CQduUrXc/fbgNuaFa8FTmmhbi1wdU/E1R5DctIA+NvSbVxx4ogj1BYR6X9053gHzTlmMAB/eK0kzpGIiMSHEkcH5WWmcMHUIZTuqot3KCIicaHEcRQG56Sycns1y7ZUxTsUEZEep8RxFN45sxiAxxdtOUJNEZH+R4njKBw3MpdR+Rls02y5IjIAKXEcpeG5aSzdXEUcbykREYkLJY6jdPr4QtaW72b3Pt3PISIDixLHUcrNTAFgz76GI9QUEelflDiOUkZyIgB71eMQkQFGieMoZaREiWOPEoeIDDBKHEcp/UDi0FCViAwsShxHKTM1muZLPQ4RGWiUOI7S0DDZ4fry3XGORESkZylxHKWReekUZqWypETTjojIwKLEcZTMjHGFmWyo2BPvUEREepQSRyeMLshgfYWGqkRkYFHi6IQxhZmUVtdpzioRGVCUODrhjAmFAMxdsT3OkYiI9Bwljk44dsQgkhONzTv3xjsUEZEeo8TRCQkJxpiCTP74eglVe+rjHY6ISI9Q4uikr14+je276njk1Y3xDkVEpEcocXTS7PGF5GemsE43AorIAKHE0QUmD8nmsdc3s1H3dIjIAKDE0QW+fsU06hv389D8DfEORUSk2ylxdIEJg7O5ePpQfjN/I1urdIWViPRvShxd5Mazx9Ow3/nSn5bGOxQRkW6lxNFFjhuZy/tOHcULq8qprddU6yLSfylxdKHTJxSyr3E/r67fEe9QRES6jRJHFzplbD45aUn8bkFJvEMREek2ShxdKDM1ibMnD2beWxXUNWi4SkT6JyWOLnbJ9KGU19SxYP3OeIciItItjpg4zCzTzBLC60lmdpmZJXd/aH3TKWPzAXhg3noaGvfHNxgRkW7Qnh7HC0CamY0AngGuA+7rzqD6soKsVG6aM5Gnl23nxdXl8Q5HRKTLtSdxmLvvAf4D+JG7XwlM7cxBzSzXzB41szfNbIWZzTKzfDOba2arw8+8UNfM7E4zW2NmS8xsRmeO3ROumT0GgLWav0pE+qF2JQ4zmwW8D/hLKEvq5HF/CPzN3Y8BjgdWALcAz7j7RKKezS2h7sXAxLDcANzVyWN3u7yMZLJTk9i0Q3NXiUj/057E8UngVuCP7r7MzMYBzx7tAc0sBzgLuAfA3fe5eyVwOXB/qHY/cEV4fTnwgEdeBnLNbNjRHr8nmBnF+XoeuYj0T0dMHO7+vLtfBvw4rK9190904pjjgDLgV2b2upn90swygSHuvjUcYyswONQfAWyKeX9JKDuEmd1gZgvMbEFZWVknwusax40cxPOrynhYEx+KSD/TnquqZpnZcqLhJMzseDP7aSeOmQTMAO5y9xOB3RwclmoxhBbK/LAC97vdfaa7zywqKupEeF3jtndMY/b4Ar78+DJKdmrISkT6j/YMVf0AuBCoAHD3xURDTUerBChx9/lh/VGiRLK9aQgq/CyNqV8c8/6RwJZOHL9HpKckcuvFU2jc73zn6ZXxDkdEpMu06wZAd9/UrOiob4t2923AJjObHIrmAMuBJ4BrQtk1wOPh9RPAB8PVVacBVU1DWr3d1GE5FOen88yKUs1fJSL9RnsSxyYzmw24maWY2f8Shq064ePAw2a2BDgBuB34JvA2M1sNvC2sAzwFrAXWAL8A/ruTx+4xCQnGQ9efSlF2Kh+4Zz5/X7Yt3iGJiHSauR92uuDQCmaFRJfPnk90vuHvwE3uXtH94R2dmTNn+oIFC+IdxgEbK/Zw5U9fIjstiec+c268wxERaZGZLXT3mUeq156rqsrd/X3uPsTdB7v7+3tz0uiNRhVk8P7TRrO+Yg8VNXXxDkdEpFOOeCOfmd3ZQnEVsMDdH29hm7Tg7MlF/PCZ1Ty+aAv/dcbYeIcjInLU2nOOI43oPMTqsBwH5APXm9kPujG2fmXGqDxG5qVz9wtr2VVbH+9wRESOWnsSxwTgPHf/kbv/iOhcxxTgSuCC7gyuv/nSpVPZtquWxxf1+quJRURa1Z7EMQLIjFnPBIa7eyOgAfsOmHPMYI4Zms3tf1mhcx0i0me1J3F8G1hkZr8ys/uA14HvhmlC/tGdwfU3SYkJfPqCyeytb+SSO19kX4Oe1yEifU97rqq6B5gN/CksZ7j7L919t7t/prsD7G/OnzKYz110DNt31TFvrS5OE5G+p72Pjq0FtgI7gAlm1pkpRwY0M+O608eQlZrE7X9ZwZ59DfEOSUSkQ9ozyeGHiJ4C+DTw1fDzK90bVv+WlpzId68+npXbq/nwAwuo2qurrESk72hPj+Mm4GRgg7ufC5xINC26dMJF04fy3lNH8dKaCq78yUvU1h/19F8iIj2qPYmj1t1rAcws1d3fBCYf4T3SDrdfeSx3vOt41pbv5tpfvRLvcERE2qU9iaPEzHKJTozPNbPH6QPTmvcVV544krcfO4yX1+7QDLoi0ie056qqK9290t2/AnyJ6JGvV7T9LumIb/znsWSkJPL4os3xDkVE5IjadVWVmeWZ2XFANdGDlaZ3a1QDTE5aMpOGZLOuXM8oF5Herz2THH4duJbomRhNd6w5cF73hTXwjMrPYMH6Hbg7Zi09LVdEpHc4YuIA3gmMd/d93R3MQHbauAKeWLyFZVt2MX3EoHiHIyLSqvYMVS0Fcrs7kIHu4ulDSUowPv/HN+IdiohIm9qTOL4BvG5mT5vZE01Ldwc20ORlpnD+lCEsKanivx9eyJGezCgiEi/tGaq6H/gW8AYHz3FIN/jhe07gA/e8wlNvbOPWx97gm/95XLxDEhE5THt6HOXufqe7P+vuzzct3R7ZAJSalMjDHzqVk8fk8fiiLTQ0Kk+LSO/TnsSx0My+YWazzGxG09LtkQ1QyYkJvP+00eytb2TBhp3xDkdE5DDtGao6Mfw8LaZMl+N2oxmj8gC4/r5X+dsnz6I4PyPOEYmIHNSeO8fPbWFR0uhGxfkZ3DRnInvrG7ngjhf46xtb4x2SiMgBrfY4zOzmtt7o7t/v+nCkyafeNok5UwZz2Y9f4stPLOOi6UN1Y6CI9Apt9Tiyj7BINztuZC63X3ksZdV1rNhaHe9wRESANnoc7v7VngxEWnb25CLM4MGXN/CN/zg23uGIiLT70bESJyNy0zlnUhH/WLGdNaXqdYhI/Clx9AGfPH8S7s67736Zvfv0pEARiS8ljj7g+OJcfvzeGZTX7OPWx5ZoOhIRiasjJg4z+2LM69TuDUdac+rYfCYPyeZPi7awc099vMMRkQGs1cRhZp81s1nAVTHF87o/JGmJmfGZC6NHvW/csSfO0YjIQNZWj2MlcDUwzsxeNLO7gQIzm9wzoUlzowqiO8g3VOhJgSISP20ljp3A54E1wDnAnaH8FjP7dzfHJS0ozosSx8YK9ThEJH7aShwXAX8BxgPfB04Bdrv7de4+u7MHNrNEM3vdzJ4M62PNbL6ZrTaz35pZSihPDetrwvYxnT12X5WeksiI3HQeeXUTZdV18Q5HRAaoVhOHu3/e3ecA64GHiG4WLDKzf5nZn7vg2DcBK2LWvwXc4e4TiXo714fy64Gd7j4BuCPUG7BOn1DA5sq9nPx//+CGBxZQUaMEIiI9qz2X4z7t7q+6+91AibufAVzXmYOa2Ujg7cAvw7oRzbb7aKhyP3BFeH15WCdsn2MDeNKmb/7HcTzwX6fw3lNH8ezKUs797nMs1PTrItKD2jM77mdjVq8NZeWdPO4PgM9y8ImCBUCluzeE9RJgRHg9AtgUjtsAVIX6hzCzG8xsgZktKCsr62R4vVdCgnHWpCJuv/JY7r/uFFKSEvjPu/7NtqraeIcmIgNEh24AdPfFnT2gmV0KlLr7wtjilg7Xjm2xsd3t7jPdfWZRUVFnw+wTZk8o5EfviZ6p9fD8DXGORkQGinjcOX46cJmZrQceIRqi+gGQa2ZNky6OBLaE1yVAMUDYPgjY0ZMB92azxhdw8fSh/Oifa3hyyRb279dd5SLSvXo8cbj7re4+0t3HAO8G/unu7wOe5eDNhtcAj4fXT4R1wvZ/uubcOMQXL51KenIi//Pr1znjW/9k+ZZd8Q5JRPqx3jRX1eeAm81sDdE5jHtC+T1ENx6uAW4GbolTfL3WiNx05t58Fl95x1Qq99ZzxU9eYmvV3niHJSL9lPXHP95nzpzpCxYsiHcYcfGP5dv50AMLGFeUyf3XnaLnlYtIu5nZQnefeaR6vanHIV3g/KlD+MyFk1lbtpsPP7CAugZNwy4iXUuJox/62LkT+PTbJvHmtmrO++7zLN5UGe+QRKQfUeLopz527gRue8dUNlfu5cuPL413OCLSjyhx9FMJCcZ1p4/lPacUs7ikir8t3RrvkESkn1Di6OduuWgKYwsz+cRvFvHi6v57R72I9Bwljn5uUEYyf/zv2eSkJ3HjgwvZVaunB4pI5yhxDAC5GSlcOG0ou/c18rlHl1DfuP/IbxIRaYUSxwDxf1cey7Wzx/DXpds469vP8vwqDVuJyNFR4hhAbnvHVL539fHsrmvgxgcX8sC89XoglIh0mO4cH4DWltVw8+8Wsyjc3zFteA6/uu5kBmenxTkyEYkn3TkurRpXlMVjH53NvdfO5Mazx7N6ew1X3TWPVdur4x2aiPQBShwDVEKCcd4xQ7jl4mN4+MOnUlZdxwV3vMA3nlpB5Z598Q5PRHoxJQ7h5DH5/PrDpzJzdB4/f2EtH33oNV15JSKtUuIQAE4clcejH53NB2eNZt7aCi6981/86qV1NOrBUCLSjBKHHOK2d0zj21cdR8XuOr765+X87Pm36I8XUIjI0VPikEMkJhjvnFnMq184nxOKc/nO0yv54L2vsL58d7xDE5FeQolDWmRmPPyhU7nhrHG8uLqcc777HLc+9gYbK/bEOzQRiTPdxyFHtHzLLr7395U882YpAJceN4yvXDaNwqzUOEcmIl2pvfdxKHFIu7g7izZVctdzb/H35dvJSk3i3ScXc/LYfOYcM5ikRHVeRfo6JQ4ljm6zbEsVX31iOa+s3wHAyLx0Hr1xNkMH6c5zkb5Md45Lt5k2fBC/u3EWi778Nm69+BhKd9XxoQdepaJG816JDARKHHLUcjNS+MjZ47nzPSeyYms1s775T17fuDPeYYlIN1PikE67aPpQ7rvuZPY17Of7c1fprnORfk6JQ7rEmROL+Ei4dHfG1+by2GslunFQpJ9S4pAuc8vFx3D7lccyOCeVm3+3mHO++xy/fXUju+sa4h2aiHQhXVUlXW5fw37u+McqHnuthO276shOTeIDs0bz4TPHkZeZEu/wRKQVuhxXiSPuGvc7Ty/bxn0vreeV9TvISEnk7ccO4+JjhzJ7fCFpyYnxDlFEYihxKHH0Kq9v3Mkd/1jNi6vLcIfi/HTmfupsJQ+RXkT3cUivcuKoPB74r1NY9KULuPHs8WzasZczvvVP/rZ0a7xDE5EOUuKQHjUoI5lbLj6Ge6+dCRg3PvQap97+Dx58eUO8QxORdlLikLg475ghPPeZc/jk+RPZU9fIl/60lG88tYIG3QMi0uspcUjcZKUm8cnzJ/HMp8/mnMlF/PyFtZz3vedZvmVXvEMTkTYocUjcDc5J495rTuY7Vx1HXUMjl9z5Iqfd/gz3/msdtfWN8Q5PRJrp8cRhZsVm9qyZrTCzZWZ2UyjPN7O5ZrY6/MwL5WZmd5rZGjNbYmYzejpm6X4JCcbVM4t58uNn8qEzxtKwfz9fe3I5J319Lr94YS2l1bXxDlFEgh6/HNfMhgHD3P01M8sGFgJXANcCO9z9m2Z2C5Dn7p8zs0uAjwOXAKcCP3T3U9s6hi7H7fvcnWdWlHLbE8vYXLkXM7js+OGcObGIkXnpnFCcq0t5RbpYey/HTeqJYGK5+1Zga3hdbWYrgBHA5cA5odr9wHPA50L5Ax5luJfNLNfMhoX9SD9lZpw/dQhnTSpi3toK7npuDY8v2sLji7YAMDQnjS9dOpWLpw8lIcHiHK3IwNLjiSOWmY0BTgTmA0OakoG7bzWzwaHaCGBTzNtKQtkhicPMbgBuABg1alS3xi09JyUpgbMnFXH2pCJq6xtZtmUX9760judXlvGxX78GwAnFuVxy7FBGF2QyYXAW+RkpmtpEpBvFLXGYWRbwB+CT7r7LrNW/GlvacNj4mrvfDdwN0VBVV8UpvUdaciInjc7jpNF51NY38sSiLazYtovHF23h9qfePKRuQWYK184ew4zRecwaV6BeiUgXikviMLNkoqTxsLs/Foq3Nw1BhfMgpaG8BCiOeftIYEvPRSu9UVpyIu88Ofpn8eVLp1Kycy+rtldTtbee1aU1LN1cxffmrgJgcHYq150+lmtmjyYjJa6dbJF+ocf/F1m94VlsAAAQMUlEQVTUtbgHWOHu34/Z9ARwDfDN8PPxmPL/MbNHiE6OV+n8hsQyM4rzMyjOzzhQ5u6sLq3hlXU7eOTVjXzrb2/yvb+v5MLpQ7lw2lBOH19AQVZqHKMW6bvicVXVGcCLwBtA023Cnyc6z/E7YBSwEbja3XeERPNj4CJgD3Cdu7d5yZSuqpLmnl1ZysMvb2TeW+Xs3hfdGzK2MJNzJw/mcxdPJjVJV2iJaHZcJQ5pwd59jby4uoyFG3cyd9l21pbv5sJpQ7jqpGKG56YxbfigeIcoEjdKHEoc0g6ffXQxv1tQcmB9xqhcZo0voCgrlTMmFjJhcHYcoxPpWUocShzSDvv3O6tKq6mt389f39jKX97YSsnOvQe2Hz9yEMcX53Lq2ALOnFRITlpyHKMV6V5KHEoccpRq6xtZX7Gb3y8o4flVZawprTmw7cyJhUwfMYhJQ7I4a2KRTrBLv6LEocQhXaS6tp55b1Xw3Koy/r5sG+U1+wBITUrgzImFnD6hkOK8DCYMzmJYbppOtEufpcShxCHdZOfufby5rZqHXt7AcytLD1ylBZCbkcy7Tx7FBdOGMHVYjubTkj5FiUOJQ3qAu/NWWQ3bqupYXVrNn17fzOKSKgDMokt+z5hQyLThOUwYnE1xXjqDc9LiHLVIy5Q4lDgkTrZV1fL35dtYV76bJSVVLNyw85DtackJTB2Ww+ShOeSkJzFxcDaD0pM5e1IRKUl6RI7ET6+dHVekvxs6KI0PzhpzYH1XbT3ry3ezpXIva0preKtsN0s3V/HUG1up2lt/oF5igjF1WA4zx+SRk5bMcSMHcdq4AjJT9d9Uehf9ixTpZlESyOW4kbmHbavaU8+OPftYvKmS+esqeGXdDh6Yt4HG/dFIgBmcWJzLqeMKyE1PZurwHE4clUdmSiJtTAwq0q00VCXSC1Xu2ccLq8t5aXU589ZWsHHHnkO2HzdyEJ++YDJnTSxUApEuo3McShzSj9TWN1K1t55/rS5nTVkNv3llI5V7omGu3IxkLjl2GDNG5XHBtCG6SVGOmhKHEof0Y9W19fzx9c2UVdfx8toKXttYeWB4Ky8jmYlDspkxKo/Z4ws4YVSukom0ixKHEocMILX1jby0ppz563awpXIvy7bsYl357gPbR+VnMHN0HhOGZGEYg9Kjk+9ThuWQqIdcSaCrqkQGkLTkROZMGcKcKUMOlJVV1/Hvt8pZurmKNzZX8eclW6hvPPQPxezUJIbnpjN+cCaDs9OYOCSLCUVZpKckMmlItm5glBYpcYj0U0XZqVx+wgguP2EEAPWN+2nc77jDqu3VvLZxJ4s2VbJ9Vy2LNlaypar2sH1MG57DtOE5DM5OY1xRJvmZKYzKzyA7LZmibM3TNVApcYgMEMmJCTR1II4vzuX44kMvD95d10DJzr2U7NxDdW0D89dV8MbmKv62dBu7ahsO2192ahLF+RmMKcxgbGEm6cmJTB2ew7EjcpVU+jmd4xCRI6qoqaNybz2rt0dT0K8tq6G0uo7VpTWs2lbNnvrGAyfnAUbmpTNjVB6ThmSRk57MhKIsRhdmkpeRrOe+92I6xyEiXaYgK5WCrFTGF2W1WmdbVS2LSypZurmKl9dW8MyK7TyxeMth9YbkpHLsiEHkZqQwNCeNacNzGJGXzqj8DHIzUrqzGdJFlDhEpEsMHZTG0EFDuXDaUAAa9zv1jftZV76bbbtqeau0hso99SzdUsWb26qprW88MEV9k5F56Zw5sZDUpESmDMvm+OJcUhITyM1IIT9TSaW3UOIQkW6RmGAkJiQyZVgOU4blcO7kwYfVqaipY/nWXeza28DyrVW8sKqcp97Yxu66Bhr2HzqMPnFwFqeNKyA3I5nxRVlkpCRSkJVKcV46uRkpmiCyB+kch4j0OvWN+1mwfiel1bXsd2fZ5l3MW1vBym3VhyUUiB6qNXloNuMKM5k4JJsEM6YOz2Hm6DwSE0yXFbeTbgBU4hDpl/bsa2Bt2W4a9zurtldTXdvAiq27WLFtF+vKdh/yYC2IJoocnZ9BWnIi4wdnkZ6cSHpyIhMGZ5GWnMCEwdmkJiVwzNBskhIHdq9FJ8dFpF/KSEli+ohBAIddUtx0XmXvvkb+vnwblXvq2VpVS2l1LRU1+1i0sRJ3Z9uuWlrouJCTlsTYoizSkxMOPCdlXFEmZjB8UPohlxkPSk8esM+cV+IQkX6j6bxKWnIi7zp5VKv19u5rZPe+BrZWHkwqG3fsYU1pDXvqG9lQsZvlW3a1eP9KrOzUJAqyUijOzzhQZmaMK8wkMzWRRLMDvRyAMYWZ5KZH84blpCf32SE0JQ4RGXDSUxJJT0mkMCsVGNRqvYqaOqprG6htaGRNac2Be1XcYXVpNEy2ensNu+sOJpjKPfX8e005QIvnY5okJ1o4/kEj89IpyEwlIyUaVksIU+anJycwcUg2TTPopydHU8KYQYL1/DkcJQ4RkVY03b8CcMzQnA6/v7q2ng0V0bNU9uxr5K2yGva709DorC6tZl/D/gN199bvZ01pDVV7a9hQsYe6mG1HUpyfTlpSlDzOmVzEF94+tcOxdoQSh4hIN8lOSz5wPgbglLH57XpfQ+P+Q3orG3fsYcfug/e8bKjYfeB5LGXVdWyp2ntg25CctM6GfURKHCIivUxSYgJJMaNPk4ZkH7L9tHEFPRzRoQb2tWciItJhShwiItIhShwiItIhShwiItIhShwiItIhShwiItIhShwiItIhShwiItIh/XJadTMrAzZ0YheFQHkXhdNb9Mc2gdrVl/THNkH/atdody86UqV+mTg6y8wWtGdO+r6kP7YJ1K6+pD+2Cfpvu9qioSoREekQJQ4REekQJY6W3R3vALpBf2wTqF19SX9sE/TfdrVK5zhERKRD1OMQEZEOUeIQEZEOUeKIYWYXmdlKM1tjZrfEO56OMLNiM3vWzFaY2TIzuymU55vZXDNbHX7mhXIzsztDW5eY2Yz4tqB1ZpZoZq+b2ZNhfayZzQ9t+q2ZpYTy1LC+JmwfE8+422JmuWb2qJm9GT6zWf3ks/pU+Pe31Mx+Y2ZpffHzMrN7zazUzJbGlHX48zGza0L91WZ2TTza0h2UOAIzSwR+AlwMTAXeY2bd++DertUAfNrdpwCnAR8L8d8CPOPuE4FnwjpE7ZwYlhuAu3o+5Ha7CVgRs/4t4I7Qpp3A9aH8emCnu08A7gj1eqsfAn9z92OA44na16c/KzMbAXwCmOnu04FE4N30zc/rPuCiZmUd+nzMLB+4DTgVOAW4rSnZ9HnuriW6QGAW8HTM+q3ArfGOqxPteRx4G7ASGBbKhgErw+ufA++JqX+gXm9agJFE/0nPA54EjOgu3aTmnxvwNDArvE4K9SzebWihTTnAuuax9YPPagSwCcgPv/8ngQv76ucFjAGWHu3nA7wH+HlM+SH1+vKiHsdBTf/om5SEsj4ndPlPBOYDQ9x9K0D4OThU6yvt/QHwWWB/WC8AKt29IazHxn2gTWF7Vajf24wDyoBfhSG4X5pZJn38s3L3zcB3gY3AVqLf/0L6/ufVpKOfT5/43I6GEsdB1kJZn7tW2cyygD8An3T3XW1VbaGsV7XXzC4FSt19YWxxC1W9Hdt6kyRgBnCXu58I7ObgsEdL+kS7wjDM5cBYYDiQSTSM01xf+7yOpLV29Jf2HUaJ46ASoDhmfSSwJU6xHBUzSyZKGg+7+2OheLuZDQvbhwGlobwvtPd04DIzWw88QjRc9QMg18ySQp3YuA+0KWwfBOzoyYDbqQQocff5Yf1RokTSlz8rgPOBde5e5u71wGPAbPr+59Wko59PX/ncOkyJ46BXgYnhCpAUopN6T8Q5pnYzMwPuAVa4+/djNj0BNF3NcQ3RuY+m8g+GK0JOA6qauuG9hbvf6u4j3X0M0efxT3d/H/AscFWo1rxNTW29KtTvdX/hufs2YJOZTQ5Fc4Dl9OHPKtgInGZmGeHfY1O7+vTnFaOjn8/TwAVmlhd6YxeEsr4v3idZetMCXAKsAt4CvhDveDoY+xlE3eAlwKKwXEI0ZvwMsDr8zA/1jegqsreAN4iuhIl7O9po3znAk+H1OOAVYA3weyA1lKeF9TVh+7h4x91Ge04AFoTP609AXn/4rICvAm8CS4EHgdS++HkBvyE6T1NP1HO4/mg+H+C/QvvWANfFu11dtWjKERER6RANVYmISIcocYiISIcocYiISIcocYiISIcocYiISIcocUifZ2ZfCDOyLjGzRWZ2agfff62ZDe/ge8bEzpzabNskM3sqzJa6wsx+Z2ZDjrCv93bk+EfLzL5mZuf3xLGk/0o6chWR3svMZgGXAjPcvc7MCoGUDrw/EbiW6L6DTt/Va2ZpwF+Am939z6HsXKAI2N7K28YA7wV+3dnjHyG2RHf/cnceQwYG9TikrxsGlLt7HYC7l7v7FgAzmxMmEXwjPF8hNZSvN7Mvm9m/iGYwnQk8HHor6WZ2kpk9b2YLzezpmGkmTjKzxWY2D/hYK/G8F5jXlDRCTM+6+9LQs3jRzF4Ly+xQ5ZvAmeH4n7Lo+SPfMbNXQy/qI+H4CWb209C7ejL0aq7qQFuvNrP7Yt7TWjs/YWbLw7Ef6aoPSvqReN+BqEVLZxYgi+gu+VXAT4GzQ3ka0cykk8L6A0QTPwKsBz4bs4/nCHf7AsnAv4GisP4u4N7weknM/r9DzJTbMfv6PnBTK7FmAGnh9URgQXh9DuGu+LB+A/DF8DqV6A7zsUTTcjxF9AffUKJnW1zVwbbeF97TVju3cPDu7tx4f8Zaet+iHof0ae5eA5xE9GVbBvzWzK4FJhNNuLcqVL0fOCvmrb9tZZeTgenAXDNbBHwRGGlmg4i+RJ8P9R48inCTgV+Y2RtEU2209qCwC4jmPlpENDV+AVGiOQP4vbvv92i+q2djYu5oW1tsZ9i2hKgH9n6iB4SJHELnOKTPc/dGol7Dc+FL+RqiXkhbdrdSbsAyd591SKFZLu2bEnsZcHYr2z5FdJ7jeKJeQ20bMXzc3Q+ZEM/M3t5G/ba01NYW2xm8nSjxXAZ8ycym+cHnaYioxyF9m5lNNrOJMUUnABuIJtobY2YTQvkHgOebvz+oBrLD65VAUTjpjpklhy/OSqDKzM4I9d7Xyr5+DcyO/ZK36Fn2xxJNG77V3feHeBJbOD5EM6h+1KJp8puu0soE/gX8ZzjXMYRoiIsOtrVJi+00swSg2N2fJXqAVi7RcKDIAepxSF+XBfwo9AgaiGYhvcHda83sOuD3Fj3r4VXgZ63s4z7gZ2a2l+jRplcBd4bhqSSiZ4AsA64D7jWzPbQyPba777XoAVQ/MLMfEM2uuoTouek/Bf5gZlcTDTM19QSWAA1mtjjE8kOiK61eMzMjGoK7guhZK3OIrgBbRTSMVdXBtjbFuS+cJG/ezlXAQ6HMiJ4VXtnWvmTg0ey4In2ImWW5e42ZFRBNRX56ON8h0mPU4xDpW54MvasU4OtKGhIP6nGIiEiH6OS4iIh0iBKHiIh0iBKHiIh0iBKHiIh0iBKHiIh0yP8HiBCyphN2lZ8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualize the categories that meet the given size requirements\n",
    "\n",
    "catsizes = sort([(len(catdict[c])) for c in catdict if len(catdict[c]) >= catsize_min and len(catdict[c]) <= catsize_max])\n",
    "print('Detected %d categories of the desired size' % len(catsizes))\n",
    "pl.figure()\n",
    "pl.plot(catsizes[::-1])\n",
    "pl.title('Category Sizes (Sorted)')\n",
    "pl.ylabel('# Images')\n",
    "pl.xlabel('Sorted Categories')\n",
    "pl.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial size: 2854 categories\n",
      "Final size: 1135 categories\n",
      "\n",
      "Supercategory distributions:\n",
      "Insecta: 332\n",
      "Animalia: 22\n",
      "Aves: 478\n",
      "Reptilia: 124\n",
      "Amphibia: 50\n",
      "Mammalia: 67\n",
      "Actinopterygii: 7\n",
      "Arachnida: 22\n",
      "Mollusca: 33\n"
     ]
    }
   ],
   "source": [
    "# Pare down the category dictionary to the desired size\n",
    "\n",
    "print('Initial size: %d categories' % len(list(catdict.keys())))\n",
    "clist = list(catdict.keys())\n",
    "for c in clist:\n",
    "    if len(catdict[c]) < catsize_min or len(catdict[c]) > catsize_max:\n",
    "        catdict.pop(c)\n",
    "print('Final size: %d categories' % len(list(catdict.keys())))\n",
    "\n",
    "supercat = dict()\n",
    "for d in allinfo['categories']:\n",
    "    catid = d['id']\n",
    "    if catid in catdict:\n",
    "        sc = d['supercategory']\n",
    "        if sc in supercat:\n",
    "            supercat[sc].append(catid)\n",
    "        else:\n",
    "            supercat[sc] = [catid,]\n",
    "print('\\nSupercategory distributions:')\n",
    "for sc in supercat:\n",
    "    print(sc+':', len(supercat[sc]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Category splits assigned. \n",
      "Representation set has 908 of 1135 concepts, evaluation has 227.\n"
     ]
    }
   ],
   "source": [
    "# Create category splits\n",
    "\n",
    "if random_assign:\n",
    "    catlist = list(catdict.keys())\n",
    "    shuffle(catlist)\n",
    "    testlen = len(catlist)//5\n",
    "    testcatlist = catlist[:testlen]\n",
    "    traincatlist = catlist[testlen:]\n",
    "else:\n",
    "    traincatlist = load('helpful_files/traincatlist.pth')\n",
    "    testcatlist = load('helpful_files/testcatlist.pth')\n",
    "\n",
    "print('Category splits assigned. \\nRepresentation set has %d of %d concepts, evaluation has %d.' \n",
    "      % (len(traincatlist), len(list(catdict.keys())), len(testcatlist)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Categories completed:\n",
      "50\n",
      "100\n",
      "150\n",
      "200\n",
      "250\n",
      "300\n",
      "350\n",
      "400\n",
      "450\n",
      "500\n",
      "550\n",
      "600\n",
      "650\n",
      "700\n",
      "750\n",
      "800\n",
      "850\n",
      "900\n",
      "Representation set complete!\n"
     ]
    }
   ],
   "source": [
    "# Build representation set\n",
    "\n",
    "boxdict_smaller = dict()\n",
    "count = 0\n",
    "catlist = traincatlist\n",
    "print(\"Categories completed:\")\n",
    "for c in catlist:\n",
    "    # For each category:\n",
    "    if not os.path.exists(datapath+'train/'+str(c)):\n",
    "        os.makedirs(datapath+'train/'+str(c))\n",
    "    ims = catdict[c]\n",
    "    for imkey in ims:\n",
    "        # For each image:\n",
    "        path = pather[imkey]\n",
    "        newpath = datapath+'train/'+str(c)+'/'+path[path.rfind('/')+1:-4]+'.bmp'\n",
    "        # Downsize the image to 84x84\n",
    "        with open(path, 'rb') as f:\n",
    "            p = Image.open(f)\n",
    "            w,h = p.size\n",
    "            p = p.convert('RGB')\n",
    "        p = p.resize((84, 84), Image.BILINEAR)\n",
    "        p.save(newpath)\n",
    "        # Downsize the bounding box annotations to 10x10\n",
    "        boxes = annodict[imkey]\n",
    "        boxdict_smaller[newpath] = []\n",
    "        for boxcode in boxes:\n",
    "            box = boxdict[boxcode]\n",
    "            xmin = box[0]\n",
    "            xmax = box[2]+xmin\n",
    "            ymin = box[1]\n",
    "            ymax = box[3]+ymin\n",
    "            boxdict_smaller[newpath].append([xmin*10/w, ymin*10/h, xmax*10/w, ymax*10/h])\n",
    "    count += 1\n",
    "    if count%50 == 0:\n",
    "        print(count)\n",
    "\n",
    "print(\"Representation set complete!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Categories completed:\n",
      "50\n",
      "100\n",
      "150\n",
      "200\n",
      "Evaluation set complete!\n"
     ]
    }
   ],
   "source": [
    "# Build evaluation set\n",
    "\n",
    "count = 0\n",
    "catlist = testcatlist\n",
    "print(\"Categories completed:\")\n",
    "for c in catlist:\n",
    "    # For each category:\n",
    "    if not os.path.exists(datapath+'test/'+str(c)):\n",
    "        os.makedirs(datapath+'test/'+str(c))\n",
    "    ims = catdict[c]\n",
    "    for imkey in ims:\n",
    "        # For each image:\n",
    "        path = pather[imkey]\n",
    "        newpath = datapath+'test/'+str(c)+'/'+path[path.rfind('/')+1:-4]+'.bmp'\n",
    "        # Downsize the image to 84x84\n",
    "        with open(path, 'rb') as f:\n",
    "            p = Image.open(f)\n",
    "            w,h = p.size\n",
    "            p = p.convert('RGB')\n",
    "        p = p.resize((84, 84), Image.BILINEAR)\n",
    "        p.save(newpath)\n",
    "        # Downsize the bounding box annotations to 10x10\n",
    "        boxes = annodict[imkey]\n",
    "        boxdict_smaller[newpath] = []\n",
    "        for boxcode in boxes:\n",
    "            box = boxdict[boxcode]\n",
    "            xmin = box[0]\n",
    "            xmax = box[2]+xmin\n",
    "            ymin = box[1]\n",
    "            ymax = box[3]+ymin\n",
    "            boxdict_smaller[newpath].append([xmin*10/w, ymin*10/h, xmax*10/w, ymax*10/h])\n",
    "    count += 1\n",
    "    if count%50 == 0:\n",
    "        print(count)\n",
    "\n",
    "print(\"Evaluation set complete!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Categories completed:\n",
      "50\n",
      "100\n",
      "150\n",
      "200\n",
      "Reference images compiled!\n"
     ]
    }
   ],
   "source": [
    "# Compile reference images within the evaluation set\n",
    "\n",
    "count = 0\n",
    "catlist = testcatlist\n",
    "print(\"Categories completed:\")\n",
    "for c in catlist:\n",
    "    # For each category:\n",
    "    if not os.path.exists(datapath+'repr/'+str(c)):\n",
    "        os.makedirs(datapath+'repr/'+str(c))\n",
    "    ims = list(catdict[c])\n",
    "    ims = ims[:len(ims)//5]\n",
    "    for imkey in ims:\n",
    "        # For each image:\n",
    "        path = pather[imkey]\n",
    "        oldpath = datapath+'test/'+str(c)+'/'+path[path.rfind('/')+1:-4]+'.bmp'\n",
    "        newpath = datapath+'repr/'+str(c)+'/'+path[path.rfind('/')+1:-4]+'.bmp'\n",
    "        # Create a softlink to the corresponding evaluation set image\n",
    "        os.symlink(oldpath, newpath)\n",
    "        # Copy over the bounding box annodations from the softlinked image\n",
    "        boxdict_smaller[newpath] = boxdict_smaller[oldpath]\n",
    "    count += 1\n",
    "    if count%50 == 0:\n",
    "        print(count)\n",
    "        \n",
    "print(\"Reference images compiled!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Categories completed:\n",
      "50\n",
      "100\n",
      "150\n",
      "200\n",
      "Query images compiled!\n"
     ]
    }
   ],
   "source": [
    "# Compile query images within the evaluation set\n",
    "\n",
    "count = 0\n",
    "catlist = testcatlist\n",
    "print(\"Categories completed:\")\n",
    "for c in catlist:\n",
    "    # For each category:\n",
    "    if not os.path.exists(datapath+'query/'+str(c)):\n",
    "        os.makedirs(datapath+'query/'+str(c))\n",
    "    ims = list(catdict[c])\n",
    "    ims = ims[len(ims)//5:]\n",
    "    for imkey in ims:\n",
    "        # For each image:\n",
    "        path = pather[imkey]\n",
    "        oldpath = datapath+'test/'+str(c)+'/'+path[path.rfind('/')+1:-4]+'.bmp'\n",
    "        newpath = datapath+'query/'+str(c)+'/'+path[path.rfind('/')+1:-4]+'.bmp'\n",
    "        # Create a softlink to the corresponding evaluation set image\n",
    "        os.symlink(oldpath, newpath)\n",
    "        # Copy over the bounding box annodations from the softlinked image\n",
    "        boxdict_smaller[newpath] = boxdict_smaller[oldpath]\n",
    "    count += 1\n",
    "    if count%50 == 0:\n",
    "        print(count)\n",
    "save(boxdict_smaller, 'helpful_files/box_coords.pth')\n",
    "        \n",
    "print(\"Query images compiled!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Your md5 hash is: e11fce3d5ab8929a92ae2dd25a83b2ad\n",
      "\n",
      "Congratulations! Your dataset appears to be a faithful reproduction.\n"
     ]
    }
   ],
   "source": [
    "# Use a recursive md5 checksum to verify that the constructed dataset reproduces the original\n",
    "\n",
    "# NOTE: this code only checks the assignments of photos to categories. \n",
    "# Thus, changing any file or folder names WILL cause the check to fail, even if the images themselves are still correct.\n",
    "\n",
    "import hashlib\n",
    "\n",
    "# Get hashes of relative locations for each photo in the dataset\n",
    "hashlist = []\n",
    "subdirs = ['train/', 'test/', 'repr/', 'query/']\n",
    "for subdir in subdirs:\n",
    "    for cat in os.listdir(datapath+subdir):\n",
    "        hashlist = hashlist + [hashlib.md5((subdir+cat+'/'+file).encode()).hexdigest() \n",
    "                               for file in os.listdir(datapath+subdir+cat)]\n",
    "# Get a hash for the sorted list of hashes\n",
    "hashlist.sort()\n",
    "md5 = hashlib.md5(\"\".join(hashlist).encode()).hexdigest()\n",
    "# Compare\n",
    "print(\"Your md5 hash is:\", md5)\n",
    "print()\n",
    "if not random_assign:\n",
    "    if md5!=\"e11fce3d5ab8929a92ae2dd25a83b2ad\":\n",
    "        print(\"ALERT: Something went wrong. Your dataset does not match the original.\")\n",
    "    else:\n",
    "        print(\"Congratulations! Your dataset appears to be a faithful reproduction.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shut down the notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "Jupyter.notebook.session.delete();\n"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%javascript\n",
    "Jupyter.notebook.session.delete();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#                           CONCLUDES DOWNSCALED VERSION OF META_INAT DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
